# Copyright 2021 Ant Group Co., Ltd.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

load("@llvm-project//mlir:tblgen.bzl", "gentbl_cc_library")
load("//bazel:spu.bzl", "spu_cc_library")

# Package-wide defaults: every target below is publicly visible unless it sets
# its own `visibility`, and the package is declared under the "notice" license
# class.
package(
    default_visibility = ["//visibility:public"],
    licenses = ["notice"],
)

# Runs mlir-tblgen over passes.td with `-gen-pass-decls` and emits the result
# as passes.h.inc. The generated include is consumed through :pass_details.
gentbl_cc_library(
    name = "pphlo_pass_inc_gen",
    tbl_outs = [
        (
            ["-gen-pass-decls"],
            "passes.h.inc",
        ),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "passes.td",
    # Provides the TableGen base definitions that passes.td builds on.
    deps = [
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

# Header-only helper exposing pass_details.h to the pass implementations in
# this package. It brings along the TableGen-generated pass declarations
# (:pphlo_pass_inc_gen) and the MLIR Pass library.
spu_cc_library(
    name = "pass_details",
    hdrs = ["pass_details.h"],
    # Private on purpose: this is an implementation detail of the passes, so
    # only targets inside this package may depend on it.
    visibility = ["//visibility:private"],
    deps = [
        ":pphlo_pass_inc_gen",
        "@llvm-project//mlir:Pass",
    ],
)

# Header-only library exposing map_stablehlo_to_pphlo_op.h. It depends on both
# the PPHLO dialect and the StableHLO ops; within this file it is used by
# :hlo_legalize_to_pphlo.
spu_cc_library(
    name = "map_stablehlo_to_pphlo_op",
    hdrs = ["map_stablehlo_to_pphlo_op.h"],
    visibility = [
        "//visibility:private",  # Private detail of the op mapping; package-local only
    ],
    deps = [
        "//libspu/dialect/pphlo:dialect",
        "@stablehlo//:stablehlo_ops",
    ],
)

# Library built from value_visibility_map.{h,cc}. Publicly visible via the
# package default; within this file it is consumed by :visibility_inference.
spu_cc_library(
    name = "value_visibility_map",
    srcs = ["value_visibility_map.cc"],
    hdrs = ["value_visibility_map.h"],
    deps = [
        "//libspu/core:prelude",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
    ],
)

# Library built from visibility_inference.{h,cc}. It builds on
# :value_visibility_map and additionally needs the StableHLO op definitions;
# within this file it is consumed by :hlo_legalize_to_pphlo.
spu_cc_library(
    name = "visibility_inference",
    srcs = ["visibility_inference.cc"],
    hdrs = ["visibility_inference.h"],
    deps = [
        ":value_visibility_map",
        "//libspu/core:prelude",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@stablehlo//:stablehlo_ops",
    ],
)

# Pass library compiled from expand_secret_gather.cc. Like the other pass
# targets in this file, it lists the shared passes.h header in `hdrs`.
spu_cc_library(
    name = "expand_secret_gather",
    srcs = ["expand_secret_gather.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from hlo_legalize_to_pphlo.cc, exposing the shared
# passes.h header. Unlike the other pass targets in this file, it also depends
# on the StableHLO ops, the StableHLO->PPHLO op map, visibility inference and
# the compiler's compilation context.
spu_cc_library(
    name = "hlo_legalize_to_pphlo",
    srcs = ["hlo_legalize_to_pphlo.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":map_stablehlo_to_pphlo_op",
        ":pass_details",
        ":visibility_inference",
        "//libspu/compiler/common:compilation_context",
        "//libspu/core:prelude",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@stablehlo//:stablehlo_ops",
    ],
)

# Pass library compiled from decompose_comparison.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "decompose_comparison",
    srcs = ["decompose_comparison.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from lower_conversion_cast.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "lower_conversion_cast",
    srcs = ["lower_conversion_cast.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from decompose_minmax.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "decompose_minmax",
    srcs = ["decompose_minmax.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from reduce_truncation.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "reduce_truncation",
    srcs = ["reduce_truncation.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from lower_mixed_type_op.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "lower_mixed_type_op",
    srcs = ["lower_mixed_type_op.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from optimize_maxpool.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "optimize_maxpool",
    srcs = ["optimize_maxpool.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from optimize_select.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "optimize_select",
    srcs = ["optimize_select.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from optimize_sqrt_plus_eps.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "optimize_sqrt_plus_eps",
    srcs = ["optimize_sqrt_plus_eps.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from rewrite_div_sqrt_patterns.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "rewrite_div_sqrt_patterns",
    srcs = ["rewrite_div_sqrt_patterns.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from optimize_denominator_with_broadcast.cc, exposing
# the shared passes.h header.
spu_cc_library(
    name = "optimize_denominator_with_broadcast",
    srcs = ["optimize_denominator_with_broadcast.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from insert_deallocation.cc, exposing the shared
# passes.h header.
spu_cc_library(
    name = "insert_deallocation",
    srcs = ["insert_deallocation.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from sort_lowering.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "sort_lowering",
    srcs = ["sort_lowering.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from convert_push_down.cc, exposing the shared passes.h
# header.
spu_cc_library(
    name = "convert_push_down",
    srcs = ["convert_push_down.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Pass library compiled from partial_sort_to_topk.cc, exposing the shared
# passes.h header. The source file is named after the target, matching every
# other pass library in this package.
spu_cc_library(
    name = "partial_sort_to_topk",
    srcs = ["partial_sort_to_topk.cc"],
    hdrs = ["passes.h"],
    deps = [
        ":pass_details",
        "//libspu/dialect/pphlo:dialect",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:TransformUtils",
    ],
)

# Aggregate target: exposes register_passes.h and depends on every pass library
# defined in this package, so a single dep on :all_passes links them all.
# When a new pass target is added to this file, add it to `deps` here as well
# (kept in alphabetical order).
spu_cc_library(
    name = "all_passes",
    hdrs = ["register_passes.h"],
    deps = [
        ":convert_push_down",
        ":decompose_comparison",
        ":decompose_minmax",
        ":expand_secret_gather",
        ":hlo_legalize_to_pphlo",
        ":insert_deallocation",
        ":lower_conversion_cast",
        ":lower_mixed_type_op",
        ":optimize_denominator_with_broadcast",
        ":optimize_maxpool",
        ":optimize_select",
        ":optimize_sqrt_plus_eps",
        ":partial_sort_to_topk",
        ":reduce_truncation",
        ":rewrite_div_sqrt_patterns",
        ":sort_lowering",
    ],
)
